# Alexander F. Gazmararian
# agazmararian@gmail.com
# Last updated: 2024-02-14

library(tidyverse)
library(tidylog)
library(here)
library(jsonlite)
library(openai)
library(progressr)

#' Unified logging helper with console and optional file output
#'
#' Supports two calling patterns:
#' 1. `log_message(type, msg)` - e.g. `log_message("INFO", "Starting process")`.
#'    `type` must be one of "INFO", "WARN", "ERROR", "DEBUG"; the output is
#'    "TYPE: msg". Any extra arguments in `...` are used as `sprintf()`
#'    arguments for `msg`, e.g. `log_message("INFO", "Read %d rows", n)`.
#' 2. `log_message(format, ...)` - e.g. `log_message("Processing %d items", n)`,
#'    formatted with `sprintf()`. A lone string is printed verbatim.
#'
#' Every line is prefixed with a "[YYYY-mm-dd HH:MM:SS]" timestamp and sent to
#' the console via `message()`. If a variable named `log_file` is visible from
#' the calling environment, is non-NULL, and its directory exists, the same
#' line is also appended to that file.
#'
#' Errors while formatting or writing never propagate: the raw inputs are
#' printed instead, together with the error text.
#' @param type_or_format The log type (pattern 1) or format string (pattern 2)
#' @param msg_or_args The message (pattern 1) or first format argument (pattern 2)
#' @param ... Additional `sprintf()` arguments
#' @return `NULL`, invisibly
log_message <- function(type_or_format, msg_or_args = NULL, ...) {
  timestamp <- format(Sys.time(), "%Y-%m-%d %H:%M:%S")
  # Capture the caller's frame before entering tryCatch() so the log_file
  # lookup is unambiguously relative to whoever called log_message().
  caller_env <- parent.frame()
  n_extra <- ...length()

  tryCatch({
    is_level <- isTRUE(type_or_format %in% c("INFO", "WARN", "ERROR", "DEBUG"))

    formatted_msg <- if (is_level && !is.null(msg_or_args)) {
      # Pattern 1: log_message(type, msg) or log_message(type, fmt, ...)
      body <- if (n_extra == 0) msg_or_args else sprintf(msg_or_args, ...)
      sprintf("%s: %s", type_or_format, body)
    } else if (is.null(msg_or_args) && n_extra == 0) {
      # Lone string: printed verbatim (no sprintf, so a literal "%" is safe)
      as.character(type_or_format)
    } else {
      # Pattern 2: log_message(format, ...)
      do.call(sprintf, c(list(type_or_format), list(msg_or_args), list(...)))
    }

    line <- sprintf("[%s] %s", timestamp, formatted_msg)
    message(line)

    # Optional file logging, driven by a `log_file` variable visible from the
    # caller (exists()/get() search enclosing environments too).
    log_file_path <- if (exists("log_file", envir = caller_env)) {
      get("log_file", envir = caller_env)
    }
    if (!is.null(log_file_path) && dir.exists(dirname(log_file_path))) {
      cat(paste0(line, "\n"), file = log_file_path, append = TRUE)
    }
  }, error = function(e) {
    # Fallback: print the raw inputs and why formatting/writing failed.
    raw <- tryCatch(
      paste(c(as.character(type_or_format), as.character(msg_or_args)),
            collapse = " "),
      error = function(e2) "<unprintable message>"
    )
    message(sprintf("[%s] Logging error (%s), printing directly: %s",
                    timestamp, conditionMessage(e), raw))
  })

  invisible(NULL)
}

#' Safe GPT API call with retries and error handling
#'
#' Calls `create_chat_completion()` up to `max_retries` times and returns the
#' `message.content` of the response's `choices` from the first attempt that
#' yields a non-NULL result. API errors are logged. Before the next attempt the
#' function waits `sleep_base ^ attempt` seconds when the error message
#' mentions a rate limit (or "429"), and 1 second otherwise. No wait is added
#' after the final attempt. Empty responses are logged and retried.
#' @param model The GPT model to use
#' @param messages List of message objects for the conversation
#' @param temperature Temperature parameter for response randomness
#' @param max_retries Maximum number of attempts (0 means no call is made)
#' @param sleep_base Base of the exponential backoff used for rate-limit errors
#' @return API response content, or `NA_character_` if every attempt fails
safe_gpt_call <- function(model, messages, temperature = 0, max_retries = 3, sleep_base = 2) {
  # seq_len() so max_retries = 0 performs no attempts (1:0 would run twice)
  for (attempt in seq_len(max_retries)) {
    is_last <- attempt == max_retries
    result <- tryCatch({
      response <- create_chat_completion(
        model = model,
        messages = messages,
        temperature = temperature
      )
      if (is.null(response) || is.null(response$choices) || length(response$choices) == 0) {
        log_message("WARN", sprintf("Empty response on attempt %d", attempt))
        NULL
      } else {
        response$choices$message.content
      }
    }, error = function(e) {
      err_msg <- conditionMessage(e)
      log_message("ERROR", sprintf("API error on attempt %d: %s", attempt, err_msg))
      # Only back off when another attempt will actually follow.
      if (!is_last) {
        if (grepl("rate limit|429", err_msg, ignore.case = TRUE)) {
          Sys.sleep(sleep_base ^ attempt) # Exponential backoff
        } else {
          Sys.sleep(1)
        }
      }
      NULL
    })
    if (!is.null(result)) return(result)
  }
  log_message("ERROR", "All retries failed")
  NA_character_
}

#' Safely download XML files with error handling
#'
#' Wraps `download.file()` so that a failed download produces a warning and a
#' `FALSE` return value instead of stopping the caller. Both thrown errors and
#' a non-zero status code returned by `download.file()` count as failures.
#' @param url URL to download from
#' @param dest_file Destination file path
#' @return Logical: `TRUE` if the download succeeded, `FALSE` otherwise
download_xml_safely <- function(url, dest_file) {
  status <- tryCatch(
    download.file(url, dest_file, mode = "wb"),
    error = function(e) {
      warning(sprintf("Failed to download from %s: %s", url, conditionMessage(e)),
              call. = FALSE)
      NULL
    }
  )
  if (is.null(status)) {
    return(FALSE)
  }
  # download.file() returns 0 on success; external methods (e.g. "curl",
  # "wget") report failure through a non-zero status instead of an error.
  if (!identical(as.integer(status), 0L)) {
    warning(sprintf("Failed to download from %s: download.file() returned status %s",
                    url, status), call. = FALSE)
    return(FALSE)
  }
  message(sprintf("Successfully downloaded XML from %s", url))
  TRUE
}

#' Save batch results
#'
#' Writes `results` to `<output_dir>[/YYYYmmdd]/batch_NNN.rds`, creating the
#' directory (recursively) if needed, and logs what was done.
#' @param results Object to save (typically a data frame of batch results)
#' @param output_dir Output directory path
#' @param batch_num Batch number (used zero-padded to 3 digits in the file name)
#' @param date_prefix Whether to add a date subdirectory (YYYYmmdd, today's
#'   date) under `output_dir`
#' @return The path of the written `.rds` file, invisibly
save_batch_results <- function(results, output_dir, batch_num, date_prefix = TRUE) {
  if (date_prefix) {
    output_dir <- file.path(output_dir, format(Sys.Date(), "%Y%m%d"))
  }

  if (!dir.exists(output_dir)) {
    # Re-check dir.exists() in case another process created it concurrently.
    if (!dir.create(output_dir, recursive = TRUE) && !dir.exists(output_dir)) {
      stop(sprintf("Could not create directory: %s", output_dir), call. = FALSE)
    }
    log_message("INFO", sprintf("Created directory: %s", output_dir))
  }

  out_file <- file.path(output_dir, sprintf("batch_%03d.rds", batch_num))
  saveRDS(results, out_file)
  log_message("INFO", sprintf("Saved batch %d", batch_num))
  invisible(out_file)
}

#' Create batches from input data
#'
#' Splits `data` into consecutive chunks of `batch_size` items (the last chunk
#' may be shorter). Vectors and lists are split by element; data frames are
#' split by row. Batches are named "1", "2", ... in order.
#' @param data Vector, list, or data frame of items to batch
#' @param batch_size Number of items per batch; a single number >= 1
#'   (intended to be a whole number)
#' @return List of batches (an empty list for empty input)
create_batches <- function(data, batch_size) {
  if (!is.numeric(batch_size) || length(batch_size) != 1 ||
      is.na(batch_size) || batch_size < 1) {
    stop("`batch_size` must be a single number >= 1", call. = FALSE)
  }
  # NROW() equals length() for vectors/lists and the row count for data
  # frames (seq_along() would count a data frame's columns instead).
  split(data, ceiling(seq_len(NROW(data)) / batch_size))
}

#' Check that the OpenAI API key is available
#'
#' Reads `OPENAI_API_KEY` from the environment (e.g. set via `.Renviron`).
#' A key that is empty or consists only of whitespace is treated as missing.
#' @return `TRUE` if a key is set; otherwise stops with instructions for
#'   replication (cached annotations) and for new annotations (adding a key)
check_api_key <- function() {
  api_key <- Sys.getenv("OPENAI_API_KEY")
  if (!nzchar(trimws(api_key))) {
    stop(paste0(
      "OPENAI_API_KEY not set.\n",
      "For REPLICATION: Cached annotations should be in data/cache/annotations/\n",
      "  If missing, ensure the cache directory was properly copied.\n",
      "For NEW ANNOTATIONS: Add your API key to .Renviron file:\n",
      "  OPENAI_API_KEY=your-key-here"
    ), call. = FALSE)
  }
  TRUE
}

#' Combine batch results into final dataset
#'
#' Reads every file in `output_dir` whose name matches `pattern` and row-binds
#' them with `dplyr::bind_rows()`. Files are ordered by the last run of digits
#' in their base name (the batch number), so `batch_1000.rds` follows
#' `batch_999.rds` instead of landing between `batch_100.rds` and
#' `batch_101.rds` as alphabetical order would place it. Files whose names
#' contain no digits are placed last.
#' @param output_dir Directory containing batch files
#' @param pattern Regular expression matched against file names
#' @return Combined data frame of all batches
combine_batch_results <- function(output_dir, pattern = "batch_\\d+\\.rds$") {
  files <- list.files(output_dir, pattern = pattern, full.names = TRUE)
  if (length(files) == 0) {
    stop("No batch files found in ", output_dir, call. = FALSE)
  }

  # Numeric batch id = last digit run in the base name (NA when absent).
  batch_ids <- suppressWarnings(as.numeric(
    sub("^.*?(\\d+)\\D*$", "\\1", basename(files), perl = TRUE)
  ))
  files <- files[order(batch_ids, files)]

  results <- dplyr::bind_rows(lapply(files, readRDS))
  log_message("INFO", sprintf("Combined %d batch files", length(files)))
  results
}